Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable tuning in Batch norm CK solver #3326

Merged
merged 13 commits into from
Oct 24, 2024
Merged

Conversation

bghimireamd
Copy link
Contributor

This PR only enabling tuning in CK solver.

src/include/miopen/batchnorm/problem_description.hpp Outdated Show resolved Hide resolved
src/include/miopen/batchnorm/problem_description.hpp Outdated Show resolved Hide resolved
src/include/miopen/batchnorm/problem_description.hpp Outdated Show resolved Hide resolved
src/include/miopen/batchnorm/problem_description.hpp Outdated Show resolved Hide resolved
src/ocl/batchnormocl.cpp Outdated Show resolved Hide resolved
src/solver/batchnorm/forward_inference_ck.cpp Outdated Show resolved Hide resolved
src/solver/batchnorm/forward_training_ck.cpp Outdated Show resolved Hide resolved
src/solver/batchnorm/forward_training_ck.cpp Outdated Show resolved Hide resolved
src/solver/batchnorm/forward_training_ck.cpp Outdated Show resolved Hide resolved
src/solver/batchnorm/forward_training_ck.cpp Outdated Show resolved Hide resolved
@DrizztDoUrden
Copy link
Contributor

IsApplicable/GetSolution in general should throw instead of asserting for several reasons:

  1. Asserts are compiled out when building in release mode. As a consequence, such code can segfault or abort otherwise if assert contains some precondition of the following code. For example:
assert(i < vec.size());
auto x = vec[i];

With i >= vec.size() in debug it would produce a meaningful error, while release build would just segfault at the second line.
2. Uses of solver methods are wrapped in try-catch blocks, which allows library to skip a faulty problem-solver pair and try next one. Both assert and segfault would terminate entire program.
Other than that, ideally, we want perf config to only contain a single case rather than every possible one and an index. They are copied around and in general interpreted as values in GenericSearch. But if the amount of cases is not great, it shouldn't be a problem.

Copy link
Contributor

@DrizztDoUrden DrizztDoUrden left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lgtm

@@ -240,7 +240,8 @@ MhaCKFlashAttentionV2Forward::GetSolution([[maybe_unused]] const ExecutionContex

fmha_runtime_args.p_drop = probability;
fmha_runtime_args.drop_seed_offset =
std::make_pair(dataFwd.dropoutSeedData, dataFwd.dropoutOffsetData);
std::make_pair(reinterpret_cast<uint64_t>(dataFwd.dropoutSeedData),
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@CAHEK7 had to this to silence the compiler.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd like to double check - are you going to convert gpu pointer to integer? Is the code using fmha_runtime_args.drop_seed_offset aware that it's actually gpu pointer?
For me that reinterpret_cast doesn't look like a correct solution.

Copy link
Contributor Author

@bghimireamd bghimireamd Oct 22, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes! its a GPU pointer but I think for now we have disabled the dropout https://github.com/ROCm/MIOpen/blob/develop/src/solver/mha/mha_ck_fa_v2_solver_forward.cpp#L234

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be better to make it explicitly 0 and put the same comment as

// TODO : Change API to take in probability value as host side value instead of device
// pointer to match CK API. Calling a blocking hipMemcpy will cause issues with stream,
// and isn't async.

Copy link
Contributor

@CAHEK7 CAHEK7 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks some improvements can be done here. But it will work even without any changes.

The only thing which really concerns me is reinterpret_cast in src/solver/mha/mha_ck_fa_v2_solver_forward.cpp - it doesn't look like a correct code.

src/include/miopen/batchnorm/problem_description.hpp Outdated Show resolved Hide resolved
src/solver/batchnorm/backward_ck.cpp Outdated Show resolved Hide resolved
src/solver/batchnorm/backward_ck.cpp Outdated Show resolved Hide resolved
src/solver/batchnorm/backward_ck.cpp Outdated Show resolved Hide resolved
src/solver/batchnorm/forward_inference_ck.cpp Outdated Show resolved Hide resolved
src/solver/batchnorm/backward_ck.cpp Outdated Show resolved Hide resolved
src/solver/batchnorm/backward_ck.cpp Outdated Show resolved Hide resolved
src/solver/batchnorm/backward_ck.cpp Outdated Show resolved Hide resolved
src/include/miopen/batchnorm/solvers.hpp Show resolved Hide resolved
@junliume
Copy link
Collaborator

@bghimireamd Windows builds are failing, the error messages look like something caused by this PR, could you take a look?

@CAHEK7 CAHEK7 mentioned this pull request Oct 24, 2024
@junliume junliume merged commit 2d1fd99 into develop Oct 24, 2024
32 of 144 checks passed
@junliume junliume deleted the bg/lwpmiopen_759_bn_tuning branch October 24, 2024 17:40
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants